from functools import reduce
import random

import simpy
import ipaddress

from .common import Event
from .fib import FibUpdateMessage, FibUpdate, DELETE, INSERT

# Possible values of BgpUpdate.utype.
WITHDRAW, ANNOUNCE = 0, 1

class BgpUpdate(object):
    """A single route change carried inside a BgpMessage.

    Attributes:
        utype: WITHDRAW or ANNOUNCE.
        prefix: the destination prefix (callers in this module pass it as a
            string, which the receiver parses with ``ipaddress.ip_network``).
        path: list of device ids the announcement has traversed.

    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, utype, prefix, path, **kwargs):
        self.utype, self.prefix, self.path = utype, prefix, path

    def __repr__(self):
        if self.utype == WITHDRAW:
            label = 'WITHDRAW'
        else:
            label = 'ANNOUNCE'
        return '%s - %s: %s' % (label, str(self.prefix), str(self.path))

class BgpMessage(Event):
    """An Event that carries a batch of BgpUpdate objects from ``sender``."""

    def __init__(self, timestamp, sender, updates):
        Event.__init__(self, timestamp, sender)
        self.updates = updates

    def __repr__(self):
        # One tab-indented line per carried update.
        body = ''.join('\t%s\n' % (upd,) for upd in self.updates)
        return '%.2f: from %s:\n%s' % (self.timestamp, self.sender, body)

class BgpEntry(object):
    """One route as stored in a BgpProcess RIB or FIB.

    Attributes:
        src: the neighbour the route was learned from.
        next_hop: where traffic for ``prefix`` is forwarded.
        prefix: the destination network.
        path: list of device ids the route traversed.
        metric: ``len(path)``; BgpProcess prefers smaller metrics.
    """

    def __init__(self, src, next_hop, prefix, path):
        self.src = src
        self.next_hop = next_hop
        self.prefix = prefix
        self.path = path
        self.metric = len(path)

    def as_fib_str(self):
        """Return ``'\\t<prefix>\\t<next_hop>'`` for FIB dumps."""
        prefix_col = '\t%s' % str(self.prefix)
        hop_col = '\t%s' % self.next_hop
        return prefix_col + hop_col

class BgpEntryUpdate(object):
    """An INSERT or DELETE of a BgpEntry, convertible to FIB or BGP form."""

    def __init__(self, utype, bgp_entry):
        self.utype = utype
        self.entry = bgp_entry

    def as_fib_update(self):
        """Build a FibUpdate that matches the entry's prefix as destination IP."""
        entry = self.entry
        match = {'dip': entry.prefix}
        return FibUpdate(0, self.utype, match, entry.prefix.prefixlen,
                         entry.next_hop)

    def as_bgp_update(self, device_id):
        """Build the BgpUpdate to send to peers.

        DELETE maps to WITHDRAW and anything else to ANNOUNCE. The prefix is
        sent as a string, and ``device_id`` is appended to a copy of the path.
        """
        entry = self.entry
        if self.utype == DELETE:
            kind = WITHDRAW
        else:
            kind = ANNOUNCE
        return BgpUpdate(kind, str(entry.prefix), entry.path + [device_id])

class BgpProcess(object):
    """A simulated BGP speaker driven by a simpy environment.

    For each prefix (an ``ipaddress`` network), the process keeps every route
    learned from each neighbour in ``rib``. It keeps the single preferred
    route (smallest ``metric``, i.e. shortest path) in ``fib``. Whenever the
    preferred route changes, main_loop() re-announces the change to all
    peers and pushes it as a FibUpdateMessage to the sink set via set_fpm().
    """

    def __init__(self, env, device_id, **kwargs):
        """Create the process.

        Args:
            env: the simpy environment (provides ``now`` and ``timeout``).
            device_id: identifier of this device; appended to announced
                paths and used for loop detection.
            **kwargs: extra attributes, kept in ``self.attributes``.
        """
        self.env = env
        self.device_id = device_id
        # peer_id -> keyword arguments given to add_peer(); must contain a
        # 'channel' object with a ``put`` method.
        self.peers = {}
        # Inbox consumed by main_loop().
        self.channel = simpy.Store(self.env, capacity=simpy.core.Infinity)
        self.attributes = kwargs

        # prefix (ip_network) -> {sender: BgpEntry}: all learned routes.
        self.rib = {}
        # prefix (ip_network) -> BgpEntry: the currently selected route.
        self.fib = {}
        # NOTE(review): not read by any method of this class.
        self.announced = {}

    def add_peer(self, peer_id, **kwargs):
        """Register a peer and send it a full dump of the current FIB.

        Args:
            peer_id: key under which the peer is stored in ``self.peers``.
            **kwargs: peer attributes. ``channel`` (an object with ``put``)
                is required; the dump and later broadcasts go through it.
        """
        self.peers[peer_id] = kwargs

        updates_ = [BgpEntryUpdate(INSERT, entry) for entry in self.fib.values()]
        if updates_:
            bgp_updates = self.calculate_bgp_updates(updates_)
            channel = self.peers[peer_id]['channel']
            channel.put(BgpMessage(self.env.now, self.device_id,
                                   bgp_updates))

    def set_fpm(self, fpm):
        """Set the sink (an object with ``put``) that receives FIB updates."""
        self.fpm = fpm

    def say_hello(self):
        """simpy process: send a hello Event to each peer after a 1-3 tick delay."""
        env = self.env
        # Iterate over a snapshot: this generator is suspended at each
        # ``yield``, and an add_peer() running meanwhile would otherwise
        # raise "dictionary changed size during iteration".
        for peer in list(self.peers.values()):
            channel = peer['channel']
            yield env.timeout(random.randint(1, 3))
            channel.put(Event(env.now, self.device_id))

    def update(self, sender, updates):
        """Apply a batch of BgpUpdate objects received from ``sender``.

        Updates whose path already contains this device are dropped (loop
        prevention).

        Returns:
            The list of BgpEntryUpdate changes applied to ``self.fib``, in
            application order. A replaced route always yields its DELETE
            before the INSERT of its successor, so receivers that apply them
            in order end up with the new route.
        """
        changes = []
        for u in updates:
            prefix = ipaddress.ip_network(u.prefix)
            if self.device_id in u.path:
                continue
            if u.utype == WITHDRAW:
                changes += self._withdraw(sender, prefix)
            else:  # utype == ANNOUNCE
                changes += self._announce(sender, prefix, u.path)
        return changes

    def _withdraw(self, sender, prefix):
        """Forget ``sender``'s route to ``prefix``; return the FIB changes."""
        # .get(): a withdraw for an unknown prefix is a no-op, not a KeyError.
        ribs = self.rib.get(prefix, {})
        ribs.pop(sender, None)
        fib = self.fib.get(prefix)
        if fib is None or fib.src != sender:
            return []
        changes = [BgpEntryUpdate(DELETE, fib)]
        del self.fib[prefix]
        if ribs:
            # Smallest metric wins; on a tie the later-iterated entry wins.
            best = reduce(lambda x, y: x if x.metric < y.metric else y,
                          ribs.values())
            self.fib[prefix] = best
            changes.append(BgpEntryUpdate(INSERT, best))
        return changes

    def _announce(self, sender, prefix, path):
        """Record ``sender``'s route to ``prefix``; return the FIB changes."""
        entry = BgpEntry(sender, sender, prefix, path)
        # Key by the parsed network (not the raw string) so that routes
        # already learned from other neighbours are kept.
        self.rib.setdefault(prefix, {})[sender] = entry
        fib = self.fib.get(prefix)
        # NOTE(review): a not-shorter re-announce from the current best's
        # sender leaves the FIB untouched; senders in this module withdraw
        # before re-announcing, so this is not expected to occur.
        if fib is not None and not entry.metric < fib.metric:
            return []
        self.fib[prefix] = entry
        changes = []
        if fib is not None:
            # DELETE the old route first: peers keep one route per sender and
            # apply updates in order, so DELETE-after-INSERT would drop the
            # route that was just installed.
            changes.append(BgpEntryUpdate(DELETE, fib))
        changes.append(BgpEntryUpdate(INSERT, entry))
        return changes

    def calculate_bgp_updates(self, updates):
        """Convert BgpEntryUpdate changes into BgpUpdate messages for peers."""
        return [u.as_bgp_update(self.device_id) for u in updates]

    def calculate_fib_updates(self, updates):
        """Convert BgpEntryUpdate changes into FibUpdate objects."""
        return [u.as_fib_update() for u in updates]

    def main_loop(self):
        """simpy process: consume ``self.channel`` forever.

        For each BgpMessage that changes the FIB, the changes are broadcast
        to all peers and sent to the FPM sink as one numbered batch. Every
        request is printed together with this device's id, including
        non-BGP Events (presumably hellos from say_hello).
        """
        batch_id = 0
        while True:
            req = yield self.channel.get()

            if isinstance(req, BgpMessage):
                updates = self.update(req.sender, req.updates)

                if updates:
                    bgp_updates = self.calculate_bgp_updates(updates)
                    for peer in self.peers.values():
                        peer['channel'].put(BgpMessage(self.env.now,
                                                       self.device_id,
                                                       bgp_updates))
                    fib_updates = self.calculate_fib_updates(updates)
                    batch_id += 1
                    self.fpm.put(FibUpdateMessage(self.env.now, self.device_id,
                                                  None, batch_id, fib_updates))
            print(req, self.device_id)

    def dump_fib(self):
        """Print the selected route of every prefix in ``self.fib``."""
        s = 'FIB[%s]\n' % (self.device_id)
        for entry in self.fib.values():
            s += '\t%s\n' % (entry.as_fib_str())
        print(s)
